共计 7904 个字符,预计需要花费 20 分钟才能阅读完成。
提醒:本文最后更新于 2024-08-30 15:39,文中所关联的信息可能已发生改变,请知悉!
导入库
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline
import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
from models import *
import torch
import torch.optim
from utils.feature_inversion_utils import *
from utils.perceptual_loss.perceptual_loss import get_pretrained_net
from utils.common_utils import *
# Keep cuDNN on and let it benchmark convolution algorithms; the input sizes in
# this script are fixed, so the autotuning cost is paid once.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Tensor type used for every `.type(dtype)` call below (CUDA float32).
dtype = torch.cuda.FloatTensor

PLOT = True  # show intermediate reconstructions while optimizing

# Image whose features will be inverted.
fname = './data/feature_inversion/building.jpg'
# Feature extractor; other option: 'vgg19_caffe'.
pretrained_net = 'alexnet_caffe'
# Comma-separated string of layer names to match, e.g. 'fc6,fc7'.
layers_to_use = 'fc6'
载入预训练网络
# cnn = get_pretrained_net(pretrained_net).type(dtype)
# Load a manually downloaded checkpoint from a local path instead of using the
# download helper. torch.load unpickles the whole module, so only load
# checkpoints that come from a trusted source.
cnn = torch.load('./data/feature_inversion/alexnet-torch_py3.pth').type(dtype)
opt_content = {'layers': layers_to_use, 'what': 'features'}

# Remove the layers we don't need: everything after the deepest requested layer.
keys = list(cnn._modules)
requested_layers = opt_content['layers'].split(',')
unknown_layers = [name for name in requested_layers if name not in keys]
if unknown_layers:
    raise ValueError('Unknown layer(s) %s; available layers: %s'
                     % (unknown_layers, keys))
max_idx = max(keys.index(name) for name in requested_layers)
# Iterate over the snapshot in `keys`, not the dict that is being modified.
for k in keys[max_idx + 1:]:
    cnn._modules.pop(k)
print(cnn)
其中,原始源码中这段代码的第一行
cnn = get_pretrained_net(pretrained_net).type(dtype)  # original line: fetches the weights via wget, then loads them
中的 get_pretrained_net
函数如下:
def _download_if_missing(path, url, label):
    """Download ``url`` to ``path`` unless ``path`` already exists.

    Uses urllib rather than shelling out to ``wget``, which is not available
    on every platform (e.g. a stock Windows install). The file is first
    written to ``path + '.part'`` and renamed only after a complete download,
    so an interrupted transfer never leaves a truncated checkpoint under the
    final name. TLS certificates are verified (unlike
    ``wget --no-check-certificate``).

    Raises:
        OSError: (including ``urllib.error.URLError``) if the download fails.
    """
    if os.path.exists(path):
        return
    from urllib.request import urlretrieve

    print('Downloading %s' % label)
    tmp_path = path + '.part'
    try:
        urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Clean up a partial file if the download or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def get_pretrained_net(name):
    """Load a pretrained network by name.

    Weight files are looked up in, and downloaded into, the current working
    directory.

    Args:
        name: one of 'alexnet_caffe', 'vgg19_caffe', 'vgg16_caffe',
            'vgg19_pytorch_modified'.

    Returns:
        The loaded network. For 'alexnet_caffe' this is the object pickled
        in 'alexnet-torch_py3.pth'.

    Raises:
        ValueError: if ``name`` is not a supported network.
        OSError: if a weights file has to be downloaded and the download fails.
    """
    if name == 'alexnet_caffe':
        # Alternative source: https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth
        _download_if_missing(
            'alexnet-torch_py3.pth',
            'https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download',
            'AlexNet')
        # The checkpoint is a pickled module: only load trusted files.
        return torch.load('alexnet-torch_py3.pth')
    if name == 'vgg19_caffe':
        _download_if_missing(
            'vgg19-caffe-py3.pth',
            'https://box.skoltech.ru/index.php/s/HPcOFQTjXxbmp4X/download',
            'VGG-19')
        # presumably get_vgg19_caffe reads 'vgg19-caffe-py3.pth' -- TODO confirm
        return get_vgg19_caffe()
    if name == 'vgg16_caffe':
        _download_if_missing(
            'vgg16-caffe-py3.pth',
            'https://box.skoltech.ru/index.php/s/TUZ62HnPKWdxyLr/download',
            'VGG-16')
        # presumably get_vgg16_caffe reads 'vgg16-caffe-py3.pth' -- TODO confirm
        return get_vgg16_caffe()
    if name == 'vgg19_pytorch_modified':
        # No automatic download: 'vgg_pytorch_modified.pkl' must already exist.
        # Alternative source: https://www.dropbox.com/s/xlbdo688dy4keyk/vgg19-caffe.pth?dl=1
        model = VGGModified(vgg19(pretrained=False), 0.2)
        model.load_state_dict(torch.load('vgg_pytorch_modified.pkl')['state_dict'])
        return model
    raise ValueError(
        'Unknown pretrained net %r; expected one of: alexnet_caffe, '
        'vgg19_caffe, vgg16_caffe, vgg19_pytorch_modified' % (name,))
其中,第 6 行以及另外两处类似的 wget 下载语句,例如:
os.system('wget -O alexnet-torch_py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download')  # exit status is discarded, so a failed download goes unnoticed until torch.load
在我的机器上无法正确执行(很可能是因为 Windows 默认没有 wget 命令;而且 os.system 不检查返回值，下载失败时也不会报错),所以我弃用了这个函数，改为先手动下载权重文件，再从本地路径加载，如下:
cnn = torch.load('./data/feature_inversion/alexnet-torch_py3.pth').type(dtype)  # load the manually downloaded checkpoint and convert it to dtype
修改后能够正确执行，输出如下。其中的 SourceChangeWarning 只是提示当前 PyTorch 中这些模块类的源码与保存模型时的版本不同;模型仍能正常加载，一般可以忽略:
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.normalization.CrossMapLRN2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.MaxPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
E:\anaconda3\lib\site-packages\torch\serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
Sequential((conv1): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
(relu1): ReLU()
(norm1): CrossMapLRN2d(5, alpha=0.0001, beta=0.75, k=1)
(pool1): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
(conv2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2)
(relu2): ReLU()
(norm2): CrossMapLRN2d(5, alpha=0.0001, beta=0.75, k=1)
(pool2): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
(conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3): ReLU()
(conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
(relu4): ReLU()
(conv5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
(relu5): ReLU()
(pool5): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
(torch_view): View()
(fc6): Linear(in_features=9216, out_features=4096, bias=True)
)
载入图片
# Target imsize: 227 for the AlexNet variants, 224 otherwise (VGG).
# Match on the prefix so the configured name 'alexnet_caffe' is recognized;
# an exact comparison with 'alexnet' never matches any supported name.
imsize = 227 if pretrained_net.startswith('alexnet') else 224
# Something divisible by a power of two, so the generator network's repeated
# down/upsampling divides evenly (assumed; depends on skip()'s depth).
imsize_net = 256
# VGG and Alexnet need input to be correctly normalized
preprocess, deprocess = get_preprocessor(imsize), get_deprocessor()
img_content_pil, img_content_np = get_image(fname, imsize)
# Add a leading batch dimension and convert to the target tensor type.
img_content_prerocessed = preprocess(img_content_pil)[None, :].type(dtype)
# Display the content image as the notebook cell output.
img_content_pil
设置匹配器和网络
# Record the content image's features at the selected layers: with the
# matcher in 'store' mode, a forward pass through cnn saves the targets
# (presumably via hooks installed by get_matcher -- TODO confirm).
matcher_content = get_matcher(cnn, opt_content)
matcher_content.mode = 'store'
cnn(img_content_prerocessed);

# Optimization settings.
INPUT = 'noise'
pad = 'zero'  # alternative: 'reflection'
OPT_OVER = 'net'  # alternative: 'net,input'
OPTIMIZER = 'adam'  # alternative: 'LBFGS'
LR = 0.001
num_iter = 3100
input_depth = 32

# Fixed (detached) input tensor with input_depth channels at imsize_net.
net_input = get_noise(input_depth, INPUT, imsize_net).type(dtype).detach()

# Generator: an encoder-decoder producing a 3-channel image.
net = skip(
    input_depth, 3,
    num_channels_down=[16, 32, 64, 128, 128, 128],
    num_channels_up=[16, 32, 64, 128, 128, 128],
    num_channels_skip=[4, 4, 4, 4, 4, 4],
    filter_size_down=[7, 7, 5, 5, 3, 3],
    filter_size_up=[7, 7, 5, 5, 3, 3],
    upsample_mode='nearest',
    downsample_mode='avg',
    need_sigmoid=True,
    pad=pad,
    act_fun='LeakyReLU',
).type(dtype)

# Total number of scalar parameters in the generator.
s = sum(np.prod(p.shape) for p in net.parameters())
print('Number of params: %d' % s)
迭代
def closure():
    """Run one forward/backward pass for the optimizer and return the loss.

    Generates an image from the fixed ``net_input``, runs it through the
    truncated ``cnn`` (which, with the matcher in 'match' mode, is what fills
    ``matcher_content.losses`` -- assumed hook-based, TODO confirm), sums
    those losses and backpropagates. Every 200 iterations the current image
    is plotted when ``PLOT`` is set. Increments the global counter ``i``.
    """
    global i
    # The generator works at imsize_net; crop to the target image size.
    generated = net(net_input)[:, :, :imsize, :imsize]
    cnn(vgg_preprocess_var(generated))

    loss = sum(matcher_content.losses.values())
    loss.backward()
    print('Iteration %05d Loss %.3f' % (i, loss.item()), '\r', end='')

    should_plot = PLOT and i % 200 == 0
    if should_plot:
        plot_image_grid([np.clip(torch_to_np(generated), 0, 1)], 3, 3)

    i += 1
    return loss
# Reset the global iteration counter that closure() reads and increments.
i=0
# Switch the matcher from recording target features to producing losses
# against them (closure() sums matcher_content.losses).
matcher_content.mode = 'match'
# Select what to optimize according to OPT_OVER ('net' here).
p = get_params(OPT_OVER, net, net_input)
# Run num_iter optimizer steps, each evaluating closure().
optimize(OPTIMIZER, p, closure, LR, num_iter)
受限于计算机性能，只运行了有限的一段时间，得到下面的结果,loss = 0.112。
正文完
发表至: 图像处理
2022-07-31